import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import random
from torch import nn
import torch.nn.functional  as F
import argparse

class CNN(nn.Module):
    """Two-layer weight-tied network with a squared activation.

    For an input ``x`` laid out as ``[batch, d, n_eps]`` the network computes
    ``W @ act(W^T @ x) / sqrt(m)`` independently for every slice along the
    last axis. Here ``W`` is a learnable ``(d, m)`` matrix shared by both
    layers, and ``act(u) = u ** 2``. ``forward`` applies this same map to two
    inputs.

    Args:
        m: number of hidden neurons (columns of ``W``).
        d: input/output dimension (rows of ``W``).
        q: stored as ``self.q``; not used inside this class.
    """

    def __init__(self, m=50, d=1000, q=1):
        super(CNN, self).__init__()

        self.q = q
        self.m = m
        # nn.Parameter already sets requires_grad=True. The randn values are
        # overwritten by the init below, but the draw is kept so the global
        # RNG stream -- and therefore every seeded run -- is unchanged.
        self.W = torch.nn.Parameter(torch.randn(d, m))
        nn.init.normal_(self.W, std=0.001)

    def act(self, input):
        """Elementwise activation: ``input ** 2``."""
        return torch.pow(input, 2)

    def _denoise(self, x):
        """Map ``x`` of shape [b, d, j] to ``W @ act(W^T @ x) / sqrt(m)``, shape [b, d, j]."""
        hidden = torch.einsum('bij,ik->bkj', x, self.W)  # [b, m, j]
        return torch.einsum('bkj,ki->bij', self.act(hidden), self.W.T) / np.sqrt(self.m)

    def forward(self, x1, x2, verbose=False):
        """Apply the shared map to both inputs.

        Args:
            x1, x2: tensors of shape [n, d, n_eps].
            verbose: accepted for interface compatibility; unused.

        Returns:
            Tuple ``(z1, z2)`` with the same shapes as ``x1`` and ``x2``.
        """
        return self._denoise(x1), self._denoise(x2)



def prepare_data(n_train=None, n_test=None, d=None, mu=None):
    """Build the two-feature synthetic dataset.

    Each sample has two views:

    * ``x1``: exactly ``mu * e_0`` for label +1 and ``mu * e_1`` for label -1.
    * ``x2``: drawn i.i.d. from a standard normal.

    Any argument left as ``None`` falls back to the module-level variable of
    the same name. The script entry point defines those variables, so the
    original zero-argument call keeps working.

    Args:
        n_train: number of training samples.
        n_test: number of test samples.
        d: ambient dimension; must be >= 2, since the features use coordinates 0 and 1.
        mu: magnitude of the two signal features.

    Returns:
        Tuple ``(train_x1, train_x2, train_y, test_x1, test_x2, test_y,
        feature1, feature2)``, where:

        * the ``*_x*`` tensors have shape ``[n, d]``;
        * the ``*_y`` tensors have shape ``[n]``, holding ``ceil(n/2)`` +1
          labels followed by ``floor(n/2)`` -1 labels;
        * ``feature1`` and ``feature2`` are 1-D tensors of length ``d``.
    """
    module_cfg = globals()
    n_train = module_cfg["n_train"] if n_train is None else n_train
    n_test = module_cfg["n_test"] if n_test is None else n_test
    d = module_cfg["d"] if d is None else d
    mu = module_cfg["mu"] if mu is None else mu

    def _labels(n):
        # Always exactly n labels. The previous int(n/2) + int(n/2) dropped one
        # sample for odd n, mismatching the randn(n, d) noise view below.
        n_neg = n // 2
        return torch.cat((torch.ones(n - n_neg), -torch.ones(n_neg)))

    train_y = _labels(n_train)
    test_y = _labels(n_test)

    feature1 = torch.zeros(d)
    feature1[0] = mu

    feature2 = torch.zeros(d)
    feature2[1] = mu

    def _signal(y):
        # Row i is feature1 where y[i] == +1 and feature2 where y[i] == -1.
        # For y in {+1, -1} this equals ((1+y)/2) f1^T + ((1-y)/2) f2^T.
        return torch.where((y > 0).unsqueeze(1), feature1, feature2)

    train_x1 = _signal(train_y)
    test_x1 = _signal(test_y)
    # Draw order (train, then test) is preserved to keep seeded runs identical.
    train_x2 = torch.randn(n_train, d)
    test_x2 = torch.randn(n_test, d)

    return train_x1, train_x2, train_y, test_x1, test_x2, test_y, feature1, feature2


def set_seed(seed: int):
    """Seed every random number generator the experiment relies on.

    Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG are always
    seeded. When CUDA is available, all GPU RNGs are seeded as well, and cuDNN
    is switched to deterministic kernels with benchmarking disabled.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers multi-GPU setups
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def forward_diffusion(x, t, eps):
    """Noise ``x`` to diffusion time ``t``.

    Computes ``exp(-t) * x + sqrt(1 - exp(-2t)) * eps``. A trailing singleton
    axis is appended to ``x`` so that it broadcasts against ``eps``; for
    example, ``x`` of shape (n, d) pairs with ``eps`` of shape (n, d, n_eps).
    """
    signal_scale = np.exp(-t)
    noise_scale = np.sqrt(1 - np.exp(-2*t))
    return signal_scale * x[..., None] + noise_scale * eps


def parse_args(args=None):
    """Parse the experiment's command-line options.

    Args:
        args: list of argument strings. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with two fields:

        * ``time`` (float, default 0.1): the diffusion time.
        * ``mu`` (int, default 15): the signal-feature magnitude.
    """
    parser = argparse.ArgumentParser(
        description="Train the weight-tied denoiser on synthetic two-feature data."
    )
    parser.add_argument(
        "--time", type=float, default=0.1,
        help="diffusion time t passed to forward_diffusion (default: 0.1)",
    )
    parser.add_argument(
        "--mu", type=int, default=15,
        help="magnitude of the two signal features (default: 15)",
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    args = parse_args(None)

    # Experiment configuration. These stay module-level on purpose:
    # prepare_data() reads n_train, n_test, d and mu from module globals.
    seed = 100
    n_train = 30
    n_test = 3000
    d = 1000
    n_epoch = 50000
    m = 20
    mu = args.mu
    time = args.time

    n_eps = 2000  # noise draws per training sample per step

    alpha_t = np.exp(-time)
    beta_t = np.sqrt(1 - np.exp(-2*time))

    print(alpha_t, beta_t)

    set_seed(seed)

    model = CNN(m=m, d=d).to(device)
    train_x1, train_x2, train_y, test_x1, test_x2, test_y, feature1, feature2 = prepare_data()

    train_x1 = train_x1.to(device)
    train_x2 = train_x2.to(device)
    feature1 = feature1.to(device)
    feature2 = feature2.to(device)

    data_loader = DataLoader(TensorDataset(
        train_x1,
        train_x2,
        train_y
    ), batch_size=int(n_train), shuffle=False)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    train_loss_values = []

    # Only max_r |<w_r, v>| (max over the m neurons) is saved. Keep that
    # per epoch instead of the full (m, ..., n_epoch) history: the old noise
    # history alone was m*n_train*n_epoch float64 values (~240 MB), plus an
    # unused (m, m, n_epoch) Gram-matrix history (~160 MB).
    noise_series = np.zeros((n_train, n_epoch))
    feature_series1 = np.zeros(n_epoch)
    feature_series2 = np.zeros(n_epoch)

    for ep in range(n_epoch):
        train_loss = 0.0
        model.train()
        for sample_x1, sample_x2, _ in data_loader:
            batch = sample_x1.shape[0]

            # Size the noise by the actual batch so a different batch_size
            # cannot cause a broadcast mismatch.
            eps1 = torch.randn(batch, d, n_eps, device=device)
            eps2 = torch.randn(batch, d, n_eps, device=device)

            x1 = forward_diffusion(sample_x1, time, eps1)
            x2 = forward_diffusion(sample_x2, time, eps2)

            optimizer.zero_grad()
            z1, z2 = model(x1, x2)

            # The model is trained to predict the injected noise from the noised input.
            loss = F.mse_loss(z1, eps1) + F.mse_loss(z2, eps2)

            loss.backward()
            optimizer.step()
            train_loss += batch * loss.item()

        train_loss /= n_train
        train_loss_values.append(train_loss)

        model.eval()
        with torch.no_grad():
            w_trans = model.W.T  # [m, d]
            feature_series1[ep] = torch.matmul(w_trans, feature1).abs().max().item()
            feature_series2[ep] = torch.matmul(w_trans, feature2).abs().max().item()
            # [m, n_train] -> max over neurons -> [n_train]
            noise_series[:, ep] = torch.matmul(w_trans, train_x2.T).abs().amax(dim=0).cpu().numpy()

        print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}')

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'loss': train_loss_values,
        'noise': noise_series,
        'feature1': feature_series1,
        'feature2': feature_series2
    }

    torch.save(checkpoint, f'syn_diff_{time}_{mu}.pth')